DL 04L - MLflow Lab (Python)

MLflow Lab

In this lesson you:

  • Add MLflow tracking to your experiments on the Boston Housing Dataset!
  • Create a LambdaCallback
  • Create a UDF to apply your Keras model to a Spark DataFrame

Bonus:

  • Modify your model (and track the parameters) to get the lowest MSE!
%run "../Includes/Classroom-Setup"

Load & Prepare Data

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
 
boston_housing = load_boston()
 
# split 80/20 train-test
X_train, X_test, y_train, y_test = train_test_split(boston_housing.data,
                                                    boston_housing.target,
                                                    test_size=0.2,
                                                    random_state=1)
# Scale features: fit the scaler on the training set only, then apply the same scaling to test
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
 
# Train-Validation Split
X_train_split, X_val, y_train_split, y_val = train_test_split(X_train,
                                                              y_train,
                                                              test_size=0.25,
                                                              random_state=1)

Build Model

Create a build_model() function. Because Keras models are stateful, we want a fresh, untrained model each time we try out a new experiment.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
tf.random.set_seed(42)
 
def build_model():
  return Sequential([Dense(50, input_dim=13, activation="relu"),  # 13 input features
                     Dense(20, activation="relu"),
                     Dense(1, activation="linear")])               # single regression output
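As a quick sanity check (not part of the lab), calling build_model() twice returns two independent, untrained models:

# Each call builds a brand-new, untrained model
m1, m2 = build_model(), build_model()
assert m1 is not m2
m1.summary()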

Lambda Callback

Let's add EarlyStopping to our network so that training stops when a monitored metric has stopped improving.

Further, instead of logging all of the attributes of the history object once it has been trained (as in the previous notebook), let's create a LambdaCallback which can log your loss using MLflow at the end of each epoch!

NOTE: on_epoch_end expects two positional arguments: epoch and logs.

# TODO
import mlflow
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LambdaCallback
 
filepath = f"{working_dir}/keras_mlflow.ckpt"
 
checkpointer = ModelCheckpoint(filepath=filepath, verbose=1, save_best_only=True)
earlyStopping = EarlyStopping(monitor="val_loss", min_delta=0.0001, patience=2, mode="auto")
 
mlflowCallback = LambdaCallback(on_epoch_end=lambda epoch, logs: mlflow.log_metric('loss', logs['loss']))
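If you also want the validation loss tracked per epoch, one possible variation (a sketch, assuming validation_data is passed to fit so logs contains val_loss) is to log both metrics with the epoch as the MLflow step:

# Variation: log train and validation loss each epoch, using the epoch index as the step
mlflowCallback = LambdaCallback(
  on_epoch_end=lambda epoch, logs: [mlflow.log_metric("loss", logs["loss"], step=epoch),
                                    mlflow.log_metric("val_loss", logs["val_loss"], step=epoch)])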

Track Experiments!

Now let's use MLflow to track our experiments. Try changing some of these hyperparameters.

  • What do you find works best? (One way to compare finished runs is sketched after this list.)
  • Log your parameters! You can log multiple parameters as a dictionary with log_params
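To compare finished runs, one option is mlflow.search_runs, which returns a pandas DataFrame of the experiment's runs (a sketch, assuming you logged 'loss' as a metric and 'optimizer'/'batch_size' as params, as in the cell below):

# Sketch: rank completed runs by final training loss
runs_df = mlflow.search_runs(order_by=["metrics.loss ASC"])
print(runs_df[["run_id", "params.optimizer", "params.batch_size", "metrics.loss"]].head())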

A helper method to plot our training vs. validation loss using matplotlib.

import matplotlib.pyplot as plt
 
def viewModelLoss(history):
  plt.clf()
  plt.semilogy(history.history["loss"], label="train_loss")
  plt.title("model loss")
  plt.ylabel("loss")
  plt.xlabel("epoch")
  plt.semilogy(history.history["val_loss"], label="val_loss")
  plt.legend()
  return plt
# TODO
from mlflow.keras import log_model
 
optimizer = "adam"
loss = "mse" 
epochs = 30
batch_size = 32
 
with mlflow.start_run() as run:
  model = build_model()
  model.compile(optimizer=optimizer, loss=loss, metrics=["mae", "mse"])
  mlflow.log_params({'loss': loss,
                     'optimizer': optimizer,
                     'epochs': epochs,
                     'batch_size': batch_size})
 
  history = model.fit(X_train_split, 
                      y_train_split, 
                      validation_data=(X_val, y_val), 
                      epochs=epochs, 
                      batch_size=batch_size, 
                      callbacks=[checkpointer, earlyStopping, mlflowCallback], 
                      verbose=2)
 
  for i, layer in enumerate(model.layers):
    mlflow.log_param(f"hidden_layer_{i}_units", layer.units)
 
  log_model(model, "keras_model")
 
  fig = viewModelLoss(history)
  fig.savefig("train-validation-loss.png")
  mlflow.log_artifact("train-validation-loss.png")
  plt.show()
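While still inside the run context, you could also evaluate on the held-out test set and log the result (a sketch; with loss="mse" and metrics=["mae", "mse"], model.evaluate returns [loss, mae, mse]):

# Sketch: log held-out test performance to the same run.
# Keep this inside the `with mlflow.start_run()` block, since the run ends when the block exits.
test_loss, test_mae, test_mse = model.evaluate(X_test, y_test, verbose=0)
mlflow.log_metrics({"test_mae": test_mae, "test_mse": test_mse})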

User Defined Function

Let's now register our Keras model as a Spark UDF to apply to rows in parallel.

# TODO
import pandas as pd
from mlflow.tracking import MlflowClient
 
client = MlflowClient()
 
# Grab the start time and artifact URI of every run in this experiment,
# then pick the most recent one
runs = pd.DataFrame([(run.start_time, run.artifact_uri) for run in client.list_run_infos(run.info.experiment_id)],
                    columns=["start_time", "artifact_uri"])
last_run = runs.sort_values("start_time", ascending=False).iloc[0]
 
# Wrap the logged Keras model as a Spark UDF for parallel, row-wise inference
predict = mlflow.pyfunc.spark_udf(spark, last_run.artifact_uri + "/keras_model")
 
X_test_DF = spark.createDataFrame(pd.concat([pd.DataFrame(X_test, columns=boston_housing.feature_names), 
                                             pd.DataFrame(y_test, columns=["label"])], axis=1))
 
display(X_test_DF.withColumn("prediction", predict(*boston_housing.feature_names)))
 
 
(Output: the first rows of X_test_DF with the 13 scaled feature columns CRIM, ZN, INDUS, CHAS, NOX, RM, AGE, DIS, RAD, TAX, PTRATIO, B, LSTAT, plus the label and prediction columns. Showing all 102 rows.)
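As a final check, you could compute the test-set MSE directly on the Spark DataFrame (a sketch, assuming the UDF returns a double, the default result type):

# Sketch: aggregate the squared error of the UDF's predictions over the test DataFrame
from pyspark.sql import functions as F
 
pred_df = X_test_DF.withColumn("prediction", predict(*boston_housing.feature_names))
mse = pred_df.select(F.avg(F.pow(F.col("prediction") - F.col("label"), 2)).alias("mse")).first()["mse"]
print(f"Test MSE: {mse:.2f}")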